"""Token splitter."""

import itertools
import logging
import re
from typing import Callable, Generic, List, Pattern, Tuple, TypeVar

from pydantic import BaseModel, Field, PrivateAttr

from docreader.splitter.header_hook import (
    HeaderTracker,
)
from docreader.utils.split import split_by_char, split_by_sep

DEFAULT_CHUNK_OVERLAP = 100
DEFAULT_CHUNK_SIZE = 512

T = TypeVar("T")

logger = logging.getLogger(__name__)


class TextSplitter(BaseModel, Generic[T]):
    chunk_size: int = Field(description="The token chunk size for each chunk.")
    chunk_overlap: int = Field(
        description="The token overlap of each chunk when splitting."
    )
    separators: List[str] = Field(
        description="Default separators for splitting into words"
    )

    # Try to keep the matched characters as a whole.
    # If it's too long, the content will be further segmented.
    protected_regex: List[str] = Field(
        description="Protected regex for splitting into words"
    )
    len_function: Callable[[str], int] = Field(description="The length function.")
    # Header tracking Hook related attributes
    header_hook: HeaderTracker = Field(default_factory=HeaderTracker, exclude=True)

    _protected_fns: List[Pattern] = PrivateAttr()
    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        separators: List[str] = ["\n", "。", " "],
        protected_regex: List[str] = [
            # math formula
            r"\$\$[\s\S]*?\$\$",
            # image
            r"!\[.*?\]\(.*?\)",
            # link
            r"\[.*?\]\(.*?\)",
            # table header
            r"(?:\|[^|\n]*)+\|[\r\n]+\s*(?:\|\s*:?-{3,}:?\s*)+\|[\r\n]+",
            # table body
            r"(?:\|[^|\n]*)+\|[\r\n]+",
            # code header
            r"```(?:\w+)[\r\n]+[^\r\n]*",
        ],
        length_function: Callable[[str], int] = lambda x: len(x),
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
            protected_regex=protected_regex,
            len_function=length_function,
        )
        self._protected_fns = [re.compile(reg) for reg in protected_regex]
        self._split_fns = [split_by_sep(sep) for sep in separators] + [split_by_char()]

    def split_text(self, text: str) -> List[Tuple[int, int, str]]:
        """Split text into chunks."""
        if text == "":
            return []

        splits = self._split(text)
        protect = self._split_protected(text)
        splits = self._join(splits, protect)

        assert "".join(splits) == text

        chunks = self._merge(splits)
        return chunks

    def _split(self, text: str) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        NOTE: the splits contain the separators.
        """
        if self.len_function(text) <= self.chunk_size:
            return [text]

        splits = []
        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = self.len_function(split)
            if split_len <= self.chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split))
        return new_splits

    def _merge(self, splits: List[str]) -> List[Tuple[int, int, str]]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[Tuple[int, int, str]] = []

        cur_chunk: List[Tuple[int, int, str]] = []

        cur_headers, cur_len = "", 0
        cur_start, cur_end = 0, 0
        for split in splits:
            cur_end = cur_start + len(split)
            split_len = self.len_function(split)
            if split_len > self.chunk_size:
                logger.error(
                    f"Got a split of size {split_len}, ",
                    f"larger than chunk size {self.chunk_size}.",
                )

            self.header_hook.update(split)
            cur_headers = self.header_hook.get_headers()
            cur_headers_len = self.len_function(cur_headers)

            if cur_headers_len > self.chunk_size:
                logger.error(
                    f"Got headers of size {cur_headers_len}, ",
                    f"larger than chunk size {self.chunk_size}.",
                )
                cur_headers, cur_headers_len = "", 0

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len + cur_headers_len > self.chunk_size:
                # end the previous chunk
                if len(cur_chunk) > 0:
                    chunks.append(
                        (
                            cur_chunk[0][0],
                            cur_chunk[-1][1],
                            "".join([c[2] for c in cur_chunk]),
                        )
                    )

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                #   1. the current chunk length is less than chunk overlap
                #   2. the total length is less than chunk size
                while cur_chunk and (
                    cur_len > self.chunk_overlap
                    or cur_len + split_len + cur_headers_len > self.chunk_size
                ):
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= self.len_function(first_chunk[2])

                if (
                    cur_headers
                    and split_len + cur_headers_len < self.chunk_size
                    and cur_headers not in split
                ):
                    cur_chunk.insert(
                        0,
                        (
                            cur_chunk[0][0] if cur_chunk else cur_start,
                            cur_chunk[0][1] if cur_chunk else cur_end,
                            cur_headers,
                        ),
                    )
                    cur_len += cur_headers_len

            cur_chunk.append((cur_start, cur_end, split))
            cur_len += split_len
            cur_start = cur_end

        # handle the last chunk
        assert cur_chunk
        chunks.append(
            (
                cur_chunk[0][0],
                cur_chunk[-1][1],
                "".join([c[2] for c in cur_chunk]),
            )
        )

        return chunks

    def _split_protected(self, text: str) -> List[Tuple[int, str]]:
        matches = [
            (match.start(), match.end())
            for pattern in self._protected_fns
            for match in pattern.finditer(text)
        ]
        matches.sort(key=lambda x: (x[0], -x[1]))

        res = []

        def fold(initial: int, current: Tuple[int, int]) -> int:
            if current[0] >= initial:
                if current[1] - current[0] < self.chunk_size:
                    res.append((current[0], text[current[0] : current[1]]))
                else:
                    logger.warning(f"Protected text ignore: {current}")
            return max(initial, current[1])

        # filter overlapping matches
        list(itertools.accumulate(matches, fold, initial=-1))
        return res

    def _join(self, splits: List[str], protect: List[Tuple[int, str]]) -> List[str]:
        """
        Merges and splits elements in splits array based on protected substrings.

        The function processes the input splits to ensure all protected substrings
        remain as single items. If a protected substring is concatenated with preceding
        or following content in any split element, it will be separated from
        the adjacent content. The final result maintains the original order of content
        while enforcing the integrity of protected substrings.

        Key behaviors:
        1. Preserves the complete structure of each protected substring
        2. Separates protected substrings from any adjacent non-protected content
        3. Maintains the original sequence of all content except for necessary
        4. Handles cases where protected substrings are partially concatenated
        """
        j = 0
        point, start = 0, 0
        res = []

        for split in splits:
            end = start + len(split)

            cur = split[point - start :]
            while j < len(protect):
                p_start, p_content = protect[j]
                p_end = p_start + len(p_content)

                if end <= p_start:
                    break

                if point < p_start:
                    local_end = p_start - point
                    res.append(cur[:local_end])
                    cur = cur[local_end:]
                    point = p_start

                res.append(p_content)
                j += 1

                if point < p_end:
                    local_start = p_end - point
                    cur = cur[local_start:]
                    point = p_end

                if not cur:
                    break

            if cur:
                res.append(cur)
                point = end

            start = end
        return res


if __name__ == "__main__":
    s = """
    这是一些普通文本。

    | 姓名 | 年龄 | 城市 |
    |------|------|------|
    | 张三 | 25   | 北京 |
    | 李四 | 30   | 上海 |
    | 王五 | 28   | 广州 |
    | 张三 | 25   | 北京 |
    | 李四 | 30   | 上海 |
    | 王五 | 28   | 广州 |

    这是文本结束。

"""

    sp = TextSplitter(chunk_size=200, chunk_overlap=2)
    ck = sp.split_text(s)
    for c in ck:
        print("------", len(c))
        print(c)
    pass
